
Conversation

@fschlimb (Contributor)

  • lowering shard.allgather to mpi.allgather (a short sketch of the lowered IR follows below)
  • fixing lowering of shard.allreduce
  • minor refactoring
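
In rough terms, the new ConvertAllGatherOp pattern converts the input to a buffer, splits MPI_COMM_WORLD according to the process's position in the device grid, allocates the output buffer, and issues mpi.allgather. The sketch below paraphrases the CHECK lines of the updated test; the value names, the 3x4 -> 3x20 shapes, and the constant color/key values are illustrative only (in the actual lowering, color and key are the i32 linear process indices along the non-gathered and gathered grid axes, respectively):

func.func @allgather_lowered(%arg0: tensor<3x4xf32>) -> tensor<3x20xf32> {
  // Illustrative constants; the pass computes these via createProcessLinearIndex + index_cast.
  %color = arith.constant 1 : i32
  %key = arith.constant 2 : i32
  // Convert the tensor operand to a buffer.
  %buf = bufferization.to_buffer %arg0 : tensor<3x4xf32> to memref<3x4xf32>
  // Split the world communicator by grid position.
  %world = mpi.comm_world : !mpi.comm
  %comm = mpi.comm_split(%world, %color, %key) : !mpi.comm
  // Allocate the output buffer and gather into it.
  %out = memref.alloc() : memref<3x20xf32>
  mpi.allgather(%buf, %out, %comm) : memref<3x4xf32>, memref<3x20xf32>
  // Convert the result back to a tensor for tensor-typed ops.
  %res = bufferization.to_tensor %out restrict : memref<3x20xf32> to tensor<3x20xf32>
  return %res : tensor<3x20xf32>
}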

@llvmbot (Member)

llvmbot commented Jan 21, 2026

@llvm/pr-subscribers-mlir-linalg

Author: Frank Schlimbach (fschlimb)

Changes
  • lowering shard.allgather to mpi.allgather
  • fixing lowering of shard.allreduce
  • minor refactoring

Patch is 21.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/177202.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Shard/IR/ShardOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h (+6-6)
  • (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+115-70)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Shard/Transforms/Transforms.cpp (+9-8)
  • (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+43-12)
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 5e68f75ee08bf..6ef7c72d305ee 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyNon0RankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     IndexAttr:$gather_axis
   ));
   let results = (outs
-    AnyNon0RankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index 57d65e687ea35..1ddd1985389bc 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
                                  ImplicitLocOpBuilder &builder);
 
 // Get process linear index along the given grid axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
-                                               ArrayRef<GridAxis> gridAxes,
-                                               ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ArrayRef<GridAxis> gridAxes = {});
 // Get process linear index from a multi-index along the given grid axes .
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
-                         ArrayRef<GridAxis> gridAxes,
-                         ImplicitLocOpBuilder &builder);
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes = {});
 
 } // namespace shard
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index b0831dc05abb1..1865914de9d84 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
@@ -507,103 +508,147 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
   }
 }
 
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename CommOp>
+struct CommOpPattern : public OpConversionPattern<CommOp> {
+  using OpConversionPattern<CommOp>::OpConversionPattern;
 
-  LogicalResult
-  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SymbolTableCollection symbolTableCollection;
-    auto grid = adaptor.getGrid();
-    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
-    if (!gridOp)
-      return op->emitError() << "No grid found for AllReduceOp";
-    if (ShapedType::isDynamicShape(gridOp.getShape()))
-      return op->emitError()
-             << "Dynamic grid shape not supported in AllReduceOp";
-
-    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
-    Value input = adaptor.getInput();
-    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+  MemRefType getMemrefType(ShapedType tensorType) const {
+    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  }
 
+  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+    auto itype = input.getType();
     // If the source is a memref, cast it to a tensor.
-    if (isa<RankedTensorType>(input.getType())) {
-      auto memrefType = MemRefType::get(
-          inputShape, cast<ShapedType>(input.getType()).getElementType());
+    if (isa<RankedTensorType>(itype)) {
+      auto memrefType = getMemrefType(cast<ShapedType>(itype));
       input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+    } else {
+      assert(isa<MemRefType>(itype) &&
+             "expected input to be of MemRefType or TensorType");
     }
-    MemRefType inType = cast<MemRefType>(input.getType());
-
-    // Get the actual shape to allocate the buffer.
-    SmallVector<OpFoldResult> shape(inType.getRank());
-    for (auto i = 0; i < inType.getRank(); ++i) {
-      auto s = inputShape[i];
-      if (ShapedType::isDynamic(s))
-        shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
-      else
-        shape[i] = iBuilder.getIndexAttr(s);
-    }
+    return input;
+  }
 
-    // Allocate buffer and copy input to buffer.
-    Value buffer = memref::AllocOp::create(
-        iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
-    linalg::CopyOp::create(iBuilder, input, buffer);
+  FailureOr<GridOp> checkGrid(CommOp op,
+                              SymbolTableCollection &symbolTableCollection,
+                              bool allowDynamic = false) const {
+    GridOp gridOp = getGrid(op, symbolTableCollection);
+    if (!gridOp)
+      return op->emitError() << "Missing grid symbol.";
+    if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
+      return op->emitError() << "Dynamic grid shape not supported.";
+    return gridOp;
+  }
 
-    // Get an MPI_Comm_split for the AllReduce operation.
+  Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
+                ImplicitLocOpBuilder &iBuilder) const {
+    // Get an MPI_Comm_split for a given grid and axes.
     // The color is the linear index of the process in the grid along the
-    // non-reduced axes. The key is the linear index of the process in the grid
-    // along the reduced axes.
-    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
-                                       iBuilder.getIndexType());
-    SmallVector<Value> myMultiIndex =
-        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
-            .getResult();
-    Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
-    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
-
-    auto redAxes = adaptor.getGridAxes();
-    for (auto axis : redAxes) {
-      multiKey[axis] = myMultiIndex[axis];
-      myMultiIndex[axis] = zero;
+    // non-'grid-axes'. The key is the linear index of the process in the grid
+    // along the grid-axes.
+    SmallVector<GridAxis> otherAxes;
+    for (GridAxis i = 0; i < static_cast<GridAxis>(gridOp.getShape().size());
+         ++i) {
+      if (!llvm::is_contained(gridAxes, i))
+        otherAxes.emplace_back(i);
     }
 
+    SmallVector<Type> indexResultTypes(otherAxes.size(),
+                                       iBuilder.getIndexType());
+
     Value color =
-        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
+        createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
     color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
-    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
+
+    Value key =
+        createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
     key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
 
     // Finally split the communicator
-    auto commType = mpi::CommType::get(op->getContext());
+    auto commType = mpi::CommType::get(gridOp->getContext());
     Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
-    auto comm =
-        mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
-            .getNewcomm();
-
-    Value buffer1d = buffer;
-    // Collapse shape to 1d if needed
-    if (inType.getRank() > 1) {
-      ReassociationIndices reassociation(inType.getRank());
-      std::iota(reassociation.begin(), reassociation.end(), 0);
-      buffer1d = memref::CollapseShapeOp::create(
-          iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
-    }
+    return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
+        .getNewcomm();
+  }
+};
 
+struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
+  using CommOpPattern::CommOpPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+    MemRefType inType = cast<MemRefType>(input.getType());
+    if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+
+    // Allocate buffer and copy input to buffer.
+    Value buffer = memref::AllocOp::create(iBuilder, outType);
+    linalg::CopyOp::create(iBuilder, input, buffer);
+    // Get the right communicator
+    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
     // Create the MPI AllReduce operation.
-    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
                              getMPIReductionOp(adaptor.getReductionAttr()),
                              comm);
 
-    // If the destination is a memref, cast it to a tensor
+    // If the destination is a tensor, convert the result buffer back to a tensor.
     if (isa<RankedTensorType>(op.getType()))
       buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                  true);
-
     rewriter.replaceOp(op, buffer);
     return success();
   }
 };
 
+struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
+  using CommOpPattern::CommOpPattern;
+
+  LogicalResult
+  matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+    MemRefType inType = cast<MemRefType>(input.getType());
+    if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+
+    // Get the right communicator
+    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
+    // Allocate output buffer
+    Value output = memref::AllocOp::create(iBuilder, outType);
+    // Create the MPI AllGather operation.
+    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+
+    // If the destination is a tensor, convert the output buffer back to a tensor.
+    if (isa<RankedTensorType>(op.getType()))
+      output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
+                                                 true);
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+
 struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -895,8 +940,8 @@ struct ConvertShardToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
-                 ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
-                                                                  ctxt);
+                 ConvertAllGatherOp, ConvertAllReduceOp,
+                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
     SymbolTableCollection stc;
     populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
     populateAllSliceOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 0ae2a9cc0318c..d0165595f9fb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand(
     ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
     ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
-      gridOp.getSymName(), reductionGridAxes, builder);
+      builder, gridOp.getSymName(), reductionGridAxes);
   Value zero = arith::ConstantIndexOp::create(builder, 0);
   Value isLeadProcess = arith::CmpIOp::create(
       builder, builder.getI1Type(), arith::CmpIPredicate::eq,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index b433b8b0be7b2..835bc443d4b2a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
 }
 
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
-                         ArrayRef<GridAxis> gridAxes,
-                         ImplicitLocOpBuilder &builder) {
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes) {
   Operation::result_range processGroupShape =
       GridShapeOp::create(builder, grid, gridAxes).getResult();
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
   return cast<TypedValue<IndexType>>(res);
 }
 
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
-                                               ArrayRef<GridAxis> gridAxes,
-                                               ImplicitLocOpBuilder &builder) {
+TypedValue<IndexType> createProcessLinearIndex(ImplicitLocOpBuilder &builder,
+                                               StringRef grid,
+                                               ArrayRef<GridAxis> gridAxes) {
   return createProcessLinearIndex(
-      grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
-      gridAxes, builder);
+      builder, grid,
+      ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+      gridAxes);
 }
 } // namespace mlir::shard
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index a0b6bfaf6fd3d..d4741102e4d3f 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -102,15 +102,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_tensor(
     // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
-    // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
     // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
     // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
     // CHECK: return [[v2]] : tensor<3x4xf32>
@@ -121,14 +120,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_memref(
     // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
     %arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
-    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
     // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
     // CHECK: return [[valloc]] : memref<3x4xf32>
     return %0 : memref<3x4xf32>
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_new_type(
     // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
     %arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
-    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
     // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
     // CHECK: return [[valloc]] : memref<3x4xf64>
     return %0 : memref<3x4xf64>
   }
+
+  // CHECK-LABEL: func @allgather_tensor
+  func.func @allgather_tensor(
+      // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+      // CHECK-SAME: -> tensor<3x20xf32>
+      %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+    // CHECK: [[vc2_i32:%.*]]...
[truncated]

@llvmbot (Member)

llvmbot commented Jan 21, 2026

@llvm/pr-subscribers-mlir


@tkarna (Contributor) left a comment


Looks good at a high level. Going forward we should ensure that bufferization works as expected.

// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>

Moving forward we should have a mechanism to deallocate the output buffer if safe to do so.

%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>

Could add a read_only attribute if the access pattern is always the same.
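
For illustration, that might look like the line below (assuming bufferization.to_buffer still supports a read_only unit attribute with this printed form; a sketch, not part of the patch):

%buf = bufferization.to_buffer %arg0 read_only : tensor<3x4xf32> to memref<3x4xf32>  // read_only: assumption — the lowering never writes through %buf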

// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
// CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32>

This might need a writable attribute to avoid copies later on.
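
A sketch of what that could look like (assuming to_tensor accepts a writable unit attribute next to restrict; illustrative only, not emitted by this patch):

%res = bufferization.to_tensor %alloc restrict writable : memref<3x20xf32> to tensor<3x20xf32>  // writable: assumption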

@fschlimb merged commit b9ab888 into llvm:main on Jan 23, 2026 (11 checks passed)
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 23, 2026
- lowering `shard.allgather` to `mpi.allgather`
- fixing lowering of `shard.allreduce`
- minor refactoring
amd-eochoalo added a commit to iree-org/iree that referenced this pull request Jan 23, 2026
Reverts carried forward:
* Local revert of llvm/llvm-project#169614 due to #22649

Other changes:
* cast shard.all_gather result type to `ShapedType` due to its result changing types to be `AnyTypeOf<[AnyMemRef, AnyRankedTensor]>` in llvm/llvm-project#177202
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 24, 2026
- lowering `shard.allgather` to `mpi.allgather`
- fixing lowering of `shard.allreduce`
- minor refactoring
keshavvinayak01 pushed a commit to iree-org/iree that referenced this pull request Jan 27, 2026
Reverts carried forward:
* Local revert of llvm/llvm-project#169614 due to #22649

Other changes:
* cast shard.all_gather result type to `ShapedType` due to its result changing types to be `AnyTypeOf<[AnyMemRef, AnyRankedTensor]>` in llvm/llvm-project#177202

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Jan 29, 2026
- lowering `shard.allgather` to `mpi.allgather`
- fixing lowering of `shard.allreduce`
- minor refactoring